{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a part-of-speech tagger with transformers (BERT)\n",
    "\n",
    "This example shows how to use Thinc and Hugging Face's [`transformers`](https://github.com/huggingface/transformers) library to implement and train a part-of-speech tagger on the Universal Dependencies [AnCora corpus](https://github.com/UniversalDependencies/UD_Spanish-AnCora). This notebook assumes familiarity with machine learning concepts, transformer models and Thinc's config system and `Model` API (see the \"Thinc for beginners\" notebook and the [documentation](https://thinc.ai/docs) for more info)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pip install \"thinc>=8.0.0\" transformers torch \"ml_datasets>=0.2.0\" \"tqdm>=4.41\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's use Thinc's `prefer_gpu` helper to make sure we're performing operations **on GPU if available**. The function should be called right after importing Thinc, and it returns a boolean indicating whether the GPU has been activated. If we're on GPU, we can also call `use_pytorch_for_gpu_memory` to route `cupy`'s memory allocation via PyTorch, so both can play together nicely."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thinc.api import prefer_gpu, use_pytorch_for_gpu_memory\n",
    "\n",
    "is_gpu = prefer_gpu()\n",
    "print(\"GPU:\", is_gpu)\n",
    "if is_gpu:\n",
    "    use_pytorch_for_gpu_memory()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview: the final config\n",
    "\n",
    "Here's the final config for the model we're building in this notebook. It references a custom `TransformersTagger` layer that takes the name of a `starter` (the pretrained model to use), and it also defines an optimizer, a learning rate schedule with warm-up, the loss function and the general training settings. You can keep the config string within your file or notebook, or save it to a `config.cfg` file and load it in via `Config.from_disk`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONFIG = \"\"\"\n",
    "[model]\n",
    "@layers = \"TransformersTagger.v1\"\n",
    "starter = \"bert-base-multilingual-cased\"\n",
    "\n",
    "[optimizer]\n",
    "@optimizers = \"Adam.v1\"\n",
    "\n",
    "[optimizer.learn_rate]\n",
    "@schedules = \"warmup_linear.v1\"\n",
    "initial_rate = 0.01\n",
    "warmup_steps = 3000\n",
    "total_steps = 6000\n",
    "\n",
    "[loss]\n",
    "@losses = \"SequenceCategoricalCrossentropy.v1\"\n",
    "\n",
    "[training]\n",
    "batch_size = 128\n",
    "words_per_subbatch = 2000\n",
    "n_epoch = 10\n",
    "\"\"\""
   ]
  },
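  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build intuition for the `[optimizer.learn_rate]` block, here is a minimal plain-Python sketch of a linear warm-up followed by a linear decay. It is an illustration only, not Thinc's implementation: the name `warmup_linear_sketch` is hypothetical, and the exact shape of `warmup_linear.v1` is an assumption rather than something shown in this notebook.\n",
    "\n",
    "```python\n",
    "from itertools import islice\n",
    "from typing import Iterator\n",
    "\n",
    "\n",
    "def warmup_linear_sketch(initial_rate: float, warmup_steps: int, total_steps: int) -> Iterator[float]:\n",
    "    '''Yield one learning rate per step (hypothetical helper, not part of Thinc).'''\n",
    "    step = 0\n",
    "    while True:\n",
    "        if step < warmup_steps:\n",
    "            # Ramp up linearly from 0 to the peak rate.\n",
    "            factor = step / max(1, warmup_steps)\n",
    "        else:\n",
    "            # Decay linearly from the peak rate down to 0 at total_steps.\n",
    "            factor = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))\n",
    "        yield factor * initial_rate\n",
    "        step += 1\n",
    "\n",
    "\n",
    "# Same values as in the config above.\n",
    "rates = list(islice(warmup_linear_sketch(0.01, 3000, 6000), 6001))\n",
    "print(rates[0], rates[1500], rates[3000], rates[6000])\n",
    "```\n",
    "\n",
    "Under these assumptions the rate starts at zero, peaks at `initial_rate` once `warmup_steps` is reached and returns to zero at `total_steps`."
   ]
  },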
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Defining the model\n",
    "\n",
    "The Thinc model we want to define should consist of three components: the transformers **tokenizer**, the actual **transformer** implemented in PyTorch and a **softmax-activated output layer**.\n",
    "\n",
    "\n",
    "### 1. Wrapping the tokenizer\n",
    "\n",
    "To make it easier to keep track of the data that's passed around (and get type errors if something goes wrong), we first create a `TokensPlus` dataclass that holds the information we need from the `transformers` tokenizer. The most important work we'll do in this class is to build an _alignment map_.\n",
    "\n",
    "Transformer models are trained on input sequences that over-segment the sentence, so that they can work with smaller vocabularies. These over-segmented units are generally called \"wordpieces\". The transformer returns a tensor with one vector per wordpiece, and we need to map that to a tensor with one vector per POS-tagged token. We'll pass those token representations into a feed-forward network to predict the tag probabilities. During the backward pass, we then need to invert this mapping, so that we can calculate the gradients with respect to the wordpieces given the gradients with respect to the tokens.\n",
    "\n",
    "To keep things relatively simple, we'll store the alignment as a list of arrays, one per sentence, with each array mapping every token in that sentence to a single wordpiece vector (its first one). To make this work, we need to run the tokenizer with `is_split_into_words=True`, which should ensure that we get at least one wordpiece per token."
   ]
  },
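  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The alignment idea can be sketched without any libraries. The wordpiece strings and indices below are made up for illustration (they are not real tokenizer output), and `gather_first` and `scatter_back` are hypothetical helper names. The first and last positions stand in for the special tokens that a tokenizer typically adds around the sentence.\n",
    "\n",
    "```python\n",
    "from typing import List\n",
    "\n",
    "# Made-up wordpieces for the tokens ['We', 'rock', '!'].\n",
    "wordpieces = ['[CLS]', 'We', 'ro', '##ck', '!', '[SEP]']\n",
    "tok2wp = [1, 2, 4]  # index of the first wordpiece of each token\n",
    "\n",
    "\n",
    "def gather_first(wp_vectors: List[List[float]], alignment: List[int]) -> List[List[float]]:\n",
    "    '''Forward pass: pick one wordpiece vector per token.'''\n",
    "    return [wp_vectors[i] for i in alignment]\n",
    "\n",
    "\n",
    "def scatter_back(d_tokvecs: List[List[float]], alignment: List[int], n_wp: int) -> List[List[float]]:\n",
    "    '''Backward pass: route token gradients to their wordpieces, zeros elsewhere.'''\n",
    "    width = len(d_tokvecs[0])\n",
    "    d_wp = [[0.0] * width for _ in range(n_wp)]\n",
    "    for tok_i, wp_i in enumerate(alignment):\n",
    "        for k in range(width):\n",
    "            d_wp[wp_i][k] += d_tokvecs[tok_i][k]\n",
    "    return d_wp\n",
    "\n",
    "\n",
    "# One toy 2-dimensional vector per wordpiece.\n",
    "wp_vectors = [[float(i), float(i) * 10] for i in range(len(wordpieces))]\n",
    "tokvecs = gather_first(wp_vectors, tok2wp)\n",
    "d_wp = scatter_back([[1.0, 1.0]] * len(tok2wp), tok2wp, len(wordpieces))\n",
    "print(tokvecs)\n",
    "print(d_wp)\n",
    "```\n",
    "\n",
    "Only the first wordpiece of each token receives a gradient in this scheme; the special tokens and the trailing wordpiece `##ck` get zeros."
   ]
  },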
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, List\n",
    "import numpy\n",
    "from thinc.types import Ints1d, Floats2d\n",
    "from dataclasses import dataclass\n",
    "import torch\n",
    "from transformers import BatchEncoding, TokenSpan\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class TokensPlus:\n",
    "    batch_size: int\n",
    "    tok2wp: List[Ints1d]\n",
    "    input_ids: torch.Tensor\n",
    "    token_type_ids: torch.Tensor\n",
    "    attention_mask: torch.Tensor\n",
    "\n",
    "    def __init__(self, inputs: List[List[str]], wordpieces: BatchEncoding):\n",
    "        self.input_ids = wordpieces[\"input_ids\"]\n",
    "        self.attention_mask = wordpieces[\"attention_mask\"]\n",
    "        self.token_type_ids = wordpieces[\"token_type_ids\"]\n",
    "        self.batch_size = self.input_ids.shape[0]\n",
    "        self.tok2wp = []\n",
    "        for i in range(self.batch_size):\n",
    "            spans = [wordpieces.word_to_tokens(i, j) for j in range(len(inputs[i]))]\n",
    "            self.tok2wp.append(self.get_wp_starts(spans))\n",
    "\n",
    "    def get_wp_starts(self, spans: List[Optional[TokenSpan]]) -> Ints1d:\n",
    "        \"\"\"Calculate an alignment mapping each token index to its first wordpiece.\"\"\"\n",
    "        alignment = numpy.zeros((len(spans)), dtype=\"i\")\n",
    "        for i, span in enumerate(spans):\n",
    "            if span is None:\n",
    "                raise ValueError(\n",
    "                    \"Token did not align to any wordpieces. Was the tokenizer \"\n",
    "                    \"run with is_split_into_words=True?\"\n",
    "                )\n",
    "            else:\n",
    "                alignment[i] = span.start\n",
    "        return alignment\n",
    "    \n",
    "\n",
    "def test_tokens_plus(name: str=\"bert-base-multilingual-cased\"):\n",
    "    from transformers import AutoTokenizer\n",
    "    inputs = [\n",
    "        [\"Our\", \"band\", \"is\", \"called\", \"worlthatmustbedivided\", \"!\"],\n",
    "        [\"We\", \"rock\", \"!\"]\n",
    "    ]\n",
    "    tokenizer = AutoTokenizer.from_pretrained(name)\n",
    "    wordpieces = tokenizer(\n",
    "        inputs,\n",
    "        is_split_into_words=True,\n",
    "        add_special_tokens=True,\n",
    "        return_token_type_ids=True,\n",
    "        return_attention_mask=True,\n",
    "        return_length=True,\n",
    "        return_tensors=\"pt\",\n",
    "        padding=\"longest\"\n",
    "    )\n",
    "    tplus = TokensPlus(inputs, wordpieces)\n",
    "    assert len(tplus.tok2wp) == len(inputs) == len(tplus.input_ids)\n",
    "    for i, align in enumerate(tplus.tok2wp):\n",
    "        assert len(align) == len(inputs[i])\n",
    "        for j in align:\n",
    "            assert 0 <= j < tplus.input_ids.shape[1]\n",
    "\n",
    "test_tokens_plus()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The wrapped tokenizer takes a list-of-lists of strings as input (the tokenized texts) and outputs a `TokensPlus` object containing the fully padded batch of wordpieces. The wrapped transformer takes a `TokensPlus` object and outputs a list of 2-dimensional arrays, one per text.\n",
    "\n",
    "1. **TransformersTokenizer**: `List[List[str]]` → `TokensPlus`\n",
    "2. **Transformer**: `TokensPlus` → `List[Floats2d]`\n",
    "\n",
    "> 💡 Since we're adding type hints everywhere (and Thinc is fully typed, too), you can run your code through [`mypy`](https://mypy.readthedocs.io/en/stable/) to find type errors and inconsistencies. If you're using an editor like Visual Studio Code, you can enable `mypy` linting and type errors will be highlighted in real time as you write code.\n",
    "\n",
    "To use the tokenizer as a layer in our network, we register a new function that returns a Thinc `Model`. The function takes the name of the pretrained weights (e.g. `\"bert-base-multilingual-cased\"`) as an argument that can later be provided via the config. After loading the `AutoTokenizer`, we can stash it in the attributes. This lets us access it at any point later on via `model.attrs[\"tokenizer\"]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import thinc\n",
    "from thinc.api import Model\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "@thinc.registry.layers(\"transformers_tokenizer.v1\")\n",
    "def TransformersTokenizer(name: str) -> Model[List[List[str]], TokensPlus]:\n",
    "    def forward(model, inputs: List[List[str]], is_train: bool):\n",
    "        tokenizer = model.attrs[\"tokenizer\"]\n",
    "        wordpieces = tokenizer(\n",
    "            inputs,\n",
    "            is_split_into_words=True,\n",
    "            add_special_tokens=True,\n",
    "            return_token_type_ids=True,\n",
    "            return_attention_mask=True,\n",
    "            return_length=True,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=\"longest\"\n",
    "        )\n",
    "        return TokensPlus(inputs, wordpieces), lambda d_tokens: []\n",
    "\n",
    "    return Model(\"tokenizer\", forward, attrs={\"tokenizer\": AutoTokenizer.from_pretrained(name)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forward pass takes the model and a list-of-lists of strings and outputs the `TokensPlus` dataclass. It also outputs a dummy callback function, to meet the API contract for Thinc models. Even though there's no way we can meaningfully \"backpropagate\" this layer, we need to make sure the function has the right signature, so that it can be used interchangeably with other layers."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Wrapping the transformer\n",
    "\n",
    "To load and wrap the transformer, we can use `transformers.AutoModel` and Thinc's `PyTorchWrapper`. The forward method of the wrapped model can take arbitrary positional arguments and keyword arguments. Here's what the wrapped model is going to look like:\n",
    "\n",
    "```python\n",
    "@thinc.registry.layers(\"transformers_model.v1\")\n",
    "def Transformer(name: str) -> Model[TokensPlus, List[Floats2d]]:\n",
    "    return PyTorchWrapper(\n",
    "        AutoModel.from_pretrained(name),\n",
    "        convert_inputs=convert_transformer_inputs,\n",
    "        convert_outputs=convert_transformer_outputs,\n",
    "    )\n",
    "```\n",
    "\n",
    "The `Transformer` layer takes our `TokensPlus` dataclass as input and outputs a list of 2-dimensional arrays. The convert functions are used to **map inputs and outputs to and from the PyTorch model**. Each function should return the converted output, and a callback to use during the backward pass. To make the arbitrary positional and keyword arguments easier to manage, Thinc uses an `ArgsKwargs` dataclass, essentially a named tuple with `args` and `kwargs` that can be spread into a function as `*ArgsKwargs.args` and `**ArgsKwargs.kwargs`. The `ArgsKwargs` objects will be passed straight into the model in the forward pass, and straight into `torch.autograd.backward` during the backward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Tuple, Callable\n",
    "from thinc.api import ArgsKwargs, torch2xp, xp2torch\n",
    "from thinc.types import Floats2d\n",
    "\n",
    "def convert_transformer_inputs(model, tokens: TokensPlus, is_train):\n",
    "    kwargs = {\n",
    "        \"input_ids\": tokens.input_ids,\n",
    "        \"attention_mask\": tokens.attention_mask,\n",
    "        \"token_type_ids\": tokens.token_type_ids,\n",
    "    }\n",
    "    return ArgsKwargs(args=(), kwargs=kwargs), lambda dX: []\n",
    "\n",
    "\n",
    "def convert_transformer_outputs(\n",
    "    model: Model,\n",
    "    inputs_outputs: Tuple[TokensPlus, Tuple[torch.Tensor]],\n",
    "    is_train: bool\n",
    ") -> Tuple[List[Floats2d], Callable]:\n",
    "    tplus, trf_outputs = inputs_outputs\n",
    "    wp_vectors = torch2xp(trf_outputs[0])\n",
    "    tokvecs = [wp_vectors[i, idx] for i, idx in enumerate(tplus.tok2wp)]\n",
    "\n",
    "    def backprop(d_tokvecs: List[Floats2d]) -> ArgsKwargs:\n",
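    "        # Scatter the token gradients back to the first wordpiece of each token.\n",
    "        # All other positions (special tokens, padding, trailing wordpieces) stay zero.\n",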
    "        d_wp_vectors = model.ops.alloc3f(*trf_outputs[0].shape, dtype=\"f\")\n",
    "        for i, idx in enumerate(tplus.tok2wp):\n",
    "            d_wp_vectors[i, idx] += d_tokvecs[i]\n",
    "        return ArgsKwargs(\n",
    "            args=(trf_outputs[0],),\n",
    "            kwargs={\"grad_tensors\": xp2torch(d_wp_vectors)},\n",
    "        )\n",
    "\n",
    "    return tokvecs, backprop"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input and output transformation functions give you full control of how data is passed into and out of the underlying PyTorch model, so you can work with PyTorch layers that expect and return arbitrary objects. Putting it all together, we now have a layer that is configured with the name of a transformer model and acts as a function mapping tokenized input to feature vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import thinc\n",
    "from thinc.api import PyTorchWrapper\n",
    "from transformers import AutoModel\n",
    "\n",
    "@thinc.registry.layers(\"transformers_model.v1\")\n",
    "def Transformer(name: str) -> Model[TokensPlus, List[Floats2d]]:\n",
    "    return PyTorchWrapper(\n",
    "        AutoModel.from_pretrained(name),\n",
    "        convert_inputs=convert_transformer_inputs,\n",
    "        convert_outputs=convert_transformer_outputs,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now combine the `TransformersTokenizer`, the `Transformer` and a softmax output layer into a feed-forward network using the `chain` combinator. The `with_array` layer transforms a sequence of data into a contiguous 2d array on the way into and out of a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thinc.api import chain, with_array, Softmax\n",
    "\n",
    "@thinc.registry.layers(\"TransformersTagger.v1\")\n",
    "def TransformersTagger(starter: str, n_tags: int = 17) -> Model[List[List[str]], List[Floats2d]]:\n",
    "    return chain(\n",
    "        TransformersTokenizer(starter),\n",
    "        Transformer(starter),\n",
    "        with_array(Softmax(n_tags)),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Training the model\n",
    "\n",
    "### Setting up model and data\n",
    "\n",
    "Since we've registered all layers via `@thinc.registry.layers`, we can construct the model, its settings and the other functions we need from a config (see `CONFIG` above). Resolving the config gives us a dictionary of ready-to-use objects: a model, an optimizer, a function to calculate the loss and the training settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from thinc.api import Config, registry\n",
    "\n",
    "C = registry.resolve(Config().from_str(CONFIG))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = C[\"model\"]\n",
    "optimizer = C[\"optimizer\"]\n",
    "calculate_loss = C[\"loss\"]\n",
    "cfg = C[\"training\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We’ve prepared a separate package [`ml-datasets`](https://github.com/explosion/ml-datasets) with loaders for some common datasets, including the AnCora data. If we're using a GPU, calling `ops.asarray` on the outputs ensures that they're converted to `cupy` arrays (instead of `numpy` arrays). Calling `Model.initialize` with a batch of inputs and outputs allows Thinc to **infer the missing dimensions**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ml_datasets\n",
    "(train_X, train_Y), (dev_X, dev_Y) = ml_datasets.ud_ancora_pos_tags()\n",
    "\n",
    "train_Y = list(map(model.ops.asarray, train_Y))  # convert to cupy if needed\n",
    "dev_Y = list(map(model.ops.asarray, dev_Y))  # convert to cupy if needed\n",
    "\n",
    "model.initialize(X=train_X[:5], Y=train_Y[:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper functions for training and evaluation\n",
    "\n",
    "Before we can train the model, we also need to set up the following helper functions for batching and evaluation:\n",
    "\n",
    "* **`minibatch_by_words`:** Group pairs of sequences into minibatches under `max_words` in size, considering padding. The size of a padded batch is the length of its longest sequence multiplied by the number of elements in the batch.\n",
    "* **`evaluate_sequences`:** Evaluate the model's predictions against sequences of two-dimensional gold-standard arrays and return the accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def minibatch_by_words(pairs, max_words):\n",
    "    pairs = list(zip(*pairs))\n",
    "    pairs.sort(key=lambda xy: len(xy[0]), reverse=True)\n",
    "    batch = []\n",
    "    for X, Y in pairs:\n",
    "        batch.append((X, Y))\n",
    "        n_words = max(len(xy[0]) for xy in batch) * len(batch)\n",
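    "        # Only split when there is a previous batch to emit, so an over-long\n",
    "        # single sequence never produces an empty batch.\n",
    "        if n_words >= max_words and len(batch) > 1:\n",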
    "            yield batch[:-1]\n",
    "            batch = [(X, Y)]\n",
    "    if batch:\n",
    "        yield batch\n",
    "\n",
    "def evaluate_sequences(model, Xs: List[List[str]], Ys: List[Floats2d], batch_size: int) -> float:\n",
    "    correct = 0.0\n",
    "    total = 0.0\n",
    "    for X, Y in model.ops.multibatch(batch_size, Xs, Ys):\n",
    "        Yh = model.predict(X)\n",
    "        for yh, y in zip(Yh, Y):\n",
    "            correct += (y.argmax(axis=1) == yh.argmax(axis=1)).sum()\n",
    "            total += y.shape[0]\n",
    "    return float(correct / total)"
   ]
  },
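  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of the padded-size rule described above, the sketch below groups plain sequence lengths instead of real `(X, Y)` pairs. The helper names `padded_size` and `group_by_padded_size` are hypothetical and only mirror the idea behind `minibatch_by_words`; they are not part of Thinc.\n",
    "\n",
    "```python\n",
    "from typing import List, Sequence\n",
    "\n",
    "\n",
    "def padded_size(lengths: Sequence[int]) -> int:\n",
    "    '''Size of a batch after padding every sequence to the longest one.'''\n",
    "    return max(lengths) * len(lengths) if lengths else 0\n",
    "\n",
    "\n",
    "def group_by_padded_size(lengths: List[int], max_words: int) -> List[List[int]]:\n",
    "    '''Greedy grouping of sequence lengths, longest first (hypothetical helper).'''\n",
    "    batches: List[List[int]] = []\n",
    "    batch: List[int] = []\n",
    "    for n in sorted(lengths, reverse=True):\n",
    "        # Close the current batch if adding this sequence would reach the limit.\n",
    "        if batch and padded_size(batch + [n]) >= max_words:\n",
    "            batches.append(batch)\n",
    "            batch = []\n",
    "        batch.append(n)\n",
    "    if batch:\n",
    "        batches.append(batch)\n",
    "    return batches\n",
    "\n",
    "\n",
    "print(padded_size([5, 3, 2]))\n",
    "print(group_by_padded_size([5, 3, 2, 7, 1], max_words=12))\n",
    "```\n",
    "\n",
    "Three sequences of lengths 5, 3 and 2 occupy 15 slots once padded, even though they only contain 10 words. Sorting by length first keeps that padding overhead small."
   ]
  },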
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The training loop\n",
    "\n",
    "Transformers often learn best with **large batch sizes**, often larger than what fits in GPU memory. But you don't have to backprop the whole batch at once. Here we consider the \"logical\" batch size (number of examples per update) separately from the physical batch size. For the physical batch size, what we care about is the **number of words** (including padding). We also sort by length, for efficiency.\n",
    "\n",
    "At the end of each logical batch, we **call the optimizer** with the accumulated gradients and **advance the learning rate schedules**. You might want to evaluate more often than once per epoch; that's up to you."
   ]
  },
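  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The separation between logical and physical batches can be illustrated with a toy one-parameter model in plain Python. This is only a sketch of gradient accumulation under simplifying assumptions (a summed squared-error loss and a plain gradient step); `grad_sum` and `update_accumulated` are hypothetical names and do not reflect how Thinc stores gradients internally.\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "Example = Tuple[float, float]  # (x, y)\n",
    "\n",
    "\n",
    "def grad_sum(w: float, examples: List[Example]) -> float:\n",
    "    '''Summed gradient of 0.5 * (w * x - y) ** 2 with respect to w.'''\n",
    "    return sum((w * x - y) * x for x, y in examples)\n",
    "\n",
    "\n",
    "def update_accumulated(w: float, logical_batch: List[Example], sub_size: int, lr: float) -> float:\n",
    "    '''Backprop sub-batch by sub-batch, then apply a single update.'''\n",
    "    accumulated = 0.0\n",
    "    for start in range(0, len(logical_batch), sub_size):\n",
    "        accumulated += grad_sum(w, logical_batch[start:start + sub_size])\n",
    "    return w - lr * accumulated  # one optimizer step per logical batch\n",
    "\n",
    "\n",
    "data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]\n",
    "w_sub = update_accumulated(0.0, data, sub_size=2, lr=0.01)\n",
    "w_full = update_accumulated(0.0, data, sub_size=len(data), lr=0.01)\n",
    "print(w_sub, w_full)\n",
    "```\n",
    "\n",
    "Both calls arrive at the same weight, because the gradients of the sub-batches add up to the gradient of the whole logical batch before the single update is applied."
   ]
  },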
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "from thinc.api import fix_random_seed\n",
    "\n",
    "fix_random_seed(0)\n",
    "\n",
    "for epoch in range(cfg[\"n_epoch\"]):\n",
    "    batches = model.ops.multibatch(cfg[\"batch_size\"], train_X, train_Y, shuffle=True)\n",
    "    for outer_batch in tqdm(batches, leave=False):\n",
    "        for batch in minibatch_by_words(outer_batch, cfg[\"words_per_subbatch\"]):\n",
    "            inputs, truths = zip(*batch)\n",
    "            inputs = list(inputs)\n",
    "            guesses, backprop = model(inputs, is_train=True)\n",
    "            backprop(calculate_loss.get_grad(guesses, truths))\n",
    "        model.finish_update(optimizer)\n",
    "        optimizer.step_schedules()\n",
    "    score = evaluate_sequences(model, dev_X, dev_Y, cfg[\"batch_size\"])\n",
    "    print(epoch, f\"{score:.3f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you like, you can call `model.to_disk` or `model.to_bytes` to save the model weights to a file on disk or to a bytestring."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
